{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "79c178bc47ac1b2f",
   "metadata": {},
   "source": [
    "# Quickstart\n",
    "\n",
    "Get started with `Flower Datasets` quickly by learning the essentials."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0f34a29f74b13cb",
   "metadata": {},
   "source": [
    "## Install Flower Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q \"flwr-datasets[vision]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f19a191a",
   "metadata": {},
   "source": [
    "If you want to use audio datasets, install the `audio` extra:\n",
    "\n",
    "```bash\n",
    "pip install -q \"flwr-datasets[audio]\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "499dd2f0d23d871e",
   "metadata": {},
   "source": [
    "## Choose the dataset\n",
    "\n",
    "To choose the dataset, go to the Hugging Face [Datasets Hub](https://huggingface.co/datasets) and search for your dataset by name. You will pass that name to the `dataset` parameter of `FederatedDataset`. Note that the name is case-sensitive.\n",
    "\n",
    "<div style=\"max-width:80%; margin-left: auto; margin-right: auto;\">\n",
    "  <img src=\"./_static/tutorial-quickstart/choose-hf-dataset.png\" alt=\"Choose HF dataset.\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9d449e6",
   "metadata": {},
   "source": [
    "Once a dataset is available on the Hugging Face Hub, it can be used in `Flower Datasets` immediately (no approval from the Flower team and no custom code are needed).\n",
    "\n",
    "Here is how it looks for the `CIFAR10` dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d66b23efb1a289",
   "metadata": {},
   "source": [
    "<div style=\"max-width:80%; margin-left: auto; margin-right: auto;\">\n",
    "  <img src=\"./_static/tutorial-quickstart/copy-dataset-name.png\" alt=\"Copy HF dataset name.\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0c146753048fb2a",
   "metadata": {},
   "source": [
    "## Partition the dataset\n",
    "\n",
    "To partition a dataset (in a basic scenario), you need to choose two things:\n",
    "1) A dataset (identified by a name),\n",
    "2) A partitioning scheme (by selecting one of the supported partitioning schemes, [see all of them here](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.html), or creating a custom partitioning scheme).\n",
    "\n",
    "\n",
    "\n",
    "**1) Dataset choice**\n",
    "\n",
    "We will pass the name of the dataset to `FederatedDataset(dataset=\"some-name\", other-parameters)`. In this example it will be: `FederatedDataset(dataset=\"uoft-cs/cifar10\", other-parameters)`\n",
    "\n",
    "**2) Partitioner choice**\n",
    "\n",
    "We will partition the dataset in an IID manner using `IidPartitioner` ([link to the docs](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.IidPartitioner.html#flwr_datasets.partitioner.IidPartitioner)). \n",
    "Only the train split of the dataset will be processed. In general, we do `FederatedDataset(dataset=\"some-name\", partitioners={\"split-name\": partitioning_scheme})`, which for this example looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a759c5b6f25c9dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr_datasets import FederatedDataset\n",
    "from flwr_datasets.partitioner import IidPartitioner\n",
    "\n",
    "fds = FederatedDataset(\n",
    "    dataset=\"uoft-cs/cifar10\", partitioners={\"train\": IidPartitioner(num_partitions=10)}\n",
    ")\n",
    "\n",
    "# Load the first partition of the \"train\" split\n",
    "partition = fds.load_partition(0, \"train\")\n",
    "# You can access the whole \"test\" split of the base dataset (it hasn't been partitioned)\n",
    "centralized_dataset = fds.load_split(\"test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de75d15c3f5b2383",
   "metadata": {},
   "source": [
    "Now we have 10 partitions created from the train split of the CIFAR10 dataset, plus the unpartitioned test split\n",
    "for centralized evaluation. Later we will convert the dataset from Hugging Face's `Dataset` type to the format required by PyTorch, NumPy, or TensorFlow."
   ]
  },
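  {
   "cell_type": "markdown",
   "id": "iid-split-sketch",
   "metadata": {},
   "source": [
    "To build intuition for what an IID split means, here is a minimal, library-independent sketch. It is not how `IidPartitioner` is implemented; `iid_split` and its parameters are hypothetical names used only for illustration. The sketch shuffles the sample indices once and cuts them into non-overlapping chunks of nearly equal size.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def iid_split(num_samples, num_partitions, seed=42):\n",
    "    # Hypothetical helper: shuffle the indices, then slice them into contiguous chunks\n",
    "    indices = list(range(num_samples))\n",
    "    random.Random(seed).shuffle(indices)\n",
    "    base, extra = divmod(num_samples, num_partitions)\n",
    "    partitions, start = [], 0\n",
    "    for pid in range(num_partitions):\n",
    "        # The first `extra` partitions receive one additional sample\n",
    "        size = base + (1 if pid < extra else 0)\n",
    "        partitions.append(indices[start : start + size])\n",
    "        start += size\n",
    "    return partitions\n",
    "\n",
    "\n",
    "parts = iid_split(num_samples=50_000, num_partitions=10)\n",
    "print([len(p) for p in parts])\n",
    "```\n",
    "\n",
    "With 50,000 indices (the size of the CIFAR10 train split) and 10 partitions, every chunk holds 5,000 indices, and each index appears in exactly one chunk."
   ]
  },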
  {
   "cell_type": "markdown",
   "id": "efa7dbb120505f1f",
   "metadata": {},
   "source": [
    "## Investigate the partition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf986a1a9f0284cd",
   "metadata": {},
   "source": [
    "### Features\n",
    "\n",
    "Now we will determine the names of the features of your dataset (you can alternatively do that directly on the Hugging Face\n",
    "website). The names can vary along different datasets e.g. \"img\" or \"image\", \"label\" or \"labels\". Additionally, if the label column is of [ClassLabel](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.ClassLabel) type, we will also see the names of labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7ff7cecdda8a931",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': Image(mode=None, decode=True, id=None),\n",
       " 'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
    "source": [
     "# Inspect the feature names and types of the partition\n",
    "partition.features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e69ed05193a098a",
   "metadata": {},
   "source": [
    "### Indexing\n",
    "\n",
    "To see the first sample of the partition, we can index it like a Python list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f2097d4c5121a1b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       " 'label': 1}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a10ad2b97c4dd92a",
   "metadata": {},
   "source": [
    "We can then additionally select a specific column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7f0e2e29841f54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition[0][\"label\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fe1cef9a121dbc5",
   "metadata": {},
   "source": [
    "We can also use slicing (take a few samples). Let's take the first 3 samples of the first partition:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779818b365682c60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],\n",
       " 'label': [1, 2, 6]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a354aa36fc586438",
   "metadata": {},
   "source": [
    "We get a dictionary where the keys are the names of the columns and the values are lists of the corresponding values from each selected row. So to take the first 3 labels we can do:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25fca62a8f2fbe51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 6]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition[:3][\"label\"]"
   ]
  },
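  {
   "cell_type": "markdown",
   "id": "batch-to-rows-sketch",
   "metadata": {},
   "source": [
    "Slicing returns the data column by column. If per-sample records are more convenient, the dictionary of lists can be regrouped into a list of dictionaries with plain Python. The sketch below uses a stand-in dictionary with placeholder strings instead of PIL images; `batch_to_rows` is a hypothetical helper, not part of `datasets` or Flower Datasets.\n",
    "\n",
    "```python\n",
    "def batch_to_rows(batch):\n",
    "    # Hypothetical helper: turn {'col': [v0, v1, ...]} into [{'col': v0}, {'col': v1}, ...]\n",
    "    columns = list(batch.keys())\n",
    "    return [dict(zip(columns, values)) for values in zip(*batch.values())]\n",
    "\n",
    "\n",
    "# Stand-in for partition[:3]; strings replace the PIL images\n",
    "batch = {'img': ['img_0', 'img_1', 'img_2'], 'label': [1, 2, 6]}\n",
    "rows = batch_to_rows(batch)\n",
    "print(rows[0])\n",
    "```\n",
    "\n",
    "Each element of `rows` then has the same shape as a single indexed sample such as `partition[0]`."
   ]
  },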
  {
   "cell_type": "markdown",
   "id": "4e4790671ffe2142",
   "metadata": {},
   "source": [
    "Note that indexing by column first is also possible but discouraged, because the whole column will be loaded into memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7836fe6d65c673b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 6]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition[\"label\"][:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3c46099625437fc",
   "metadata": {},
   "source": [
    "You can also select a subset of the dataset and keep the same type (`datasets.Dataset`) instead of receiving a dictionary of values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "708abab74de3d5a1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 3\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition.select([0, 1, 2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "462f707b4f078a8d",
   "metadata": {},
   "source": [
    "And this dataset contains the same samples as we saw before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d2e3cc74d93c4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],\n",
       " 'label': [1, 2, 6]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "partition.select([0, 1, 2])[:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5e683cfaddf92f",
   "metadata": {},
   "source": [
    "## Use with PyTorch/NumPy/TensorFlow\n",
    "\n",
    "For more detailed instructions, go to:\n",
    "\n",
    "* [how-to-use-with-pytorch](https://flower.ai/docs/datasets/how-to-use-with-pytorch.html)\n",
    "\n",
    "* [how-to-use-with-numpy](https://flower.ai/docs/datasets/how-to-use-with-numpy.html)\n",
    "\n",
    "* [how-to-use-with-tensorflow](https://flower.ai/docs/datasets/how-to-use-with-tensorflow.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de14f09f0ee4f6ac",
   "metadata": {},
   "source": [
    "### PyTorch\n",
    "\n",
    "Apply torchvision transforms to the `Dataset` (`Compose` and all the others can be used) and wrap it in a `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94321ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "544c0e73054f3445",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "transforms = ToTensor()\n",
    "\n",
    "\n",
    "def apply_transforms(batch):\n",
    "    # For CIFAR-10 the \"img\" column contains the images we want to apply the transforms to\n",
    "    batch[\"img\"] = [transforms(img) for img in batch[\"img\"]]\n",
    "    return batch\n",
    "\n",
    "\n",
    "partition_torch = partition.with_transform(apply_transforms)\n",
    "dataloader = DataLoader(partition_torch, batch_size=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b93678a5",
   "metadata": {},
   "source": [
    "The `DataLoader` created this way does not return a `Tuple` when iterating over it but a `Dict` with the names of the columns as keys and the batched features as values. See the example below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edd3ce2",
   "metadata": {},
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Return type when iterating over a dataloader: <class 'dict'>\n",
      "torch.Size([64, 3, 32, 32])\n",
      "torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "for batch in dataloader:\n",
    "    print(f\"Return type when iterating over a dataloader: {type(batch)}\")\n",
    "    print(batch[\"img\"].shape)\n",
    "    print(batch[\"label\"].shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71531613",
   "metadata": {},
   "source": [
    "### NumPy\n",
    "\n",
    "NumPy arrays can be used as input to TensorFlow and scikit-learn models. The conversion takes a single call to `with_format`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b98b3e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "partition_np = partition.with_format(\"numpy\")\n",
    "X_train, y_train = partition_np[\"img\"], partition_np[\"label\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4867834",
   "metadata": {},
   "source": [
    "### TensorFlow Dataset\n",
    "\n",
    "Conversion to a TensorFlow `tf.data.Dataset` is a one-liner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69ce677",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q tensorflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db86f1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_dataset = partition.to_tf_dataset(\n",
    "    columns=\"img\", label_cols=\"label\", batch_size=64, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61fd797c",
   "metadata": {},
   "source": [
    "## Final remarks\n",
    "\n",
    "Congratulations, you now know the basics of Flower Datasets and are ready to perform basic dataset preparation for Federated Learning."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbdfe1b5",
   "metadata": {},
   "source": [
    "## Next\n",
    "\n",
    "This is the first tutorial in the Flower Datasets series. See the other tutorials:\n",
    "\n",
    "* [Use Partitioners](https://flower.ai/docs/datasets/tutorial-use-partitioners.html)\n",
    "\n",
    "* [Visualize Label Distribution](https://flower.ai/docs/datasets/tutorial-visualize-label-distribution.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "flwr",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
